from sfi import Data
from sfi import SFIToolkit as stata
import numpy as np
import pandas as p
import matplotlib.pyplot as plt
import sys

#Put neighboring states into a list
def getNeighborStates(statecode, neighborData):
    """Return the list of neighbor codes recorded for ``statecode``.

    Selects the rows of ``neighborData`` whose 'State' column equals
    ``statecode`` and reads the second column (position 1) of the first
    match. That cell is split on ',' and returned as a list; a missing
    (NaN) cell produces an empty list.

    Raises IndexError when no row matches ``statecode``.
    """
    matches = neighborData[neighborData['State'] == statecode]
    rawNeighbors = matches.iloc[0, 1]

    # NaN is the only value that is not equal to itself, so this detects
    # a state with no neighbors listed.
    if rawNeighbors != rawNeighbors:
        return []

    return rawNeighbors.split(',')

#Finds the first year whose rows cover every state
def _findReferenceYear(data, firstYear=1970, lastYear=2017, expectedStates=51):
    """Return the earliest year in [firstYear, lastYear] with a full set of states.

    A year qualifies when the rows of ``data`` for that year contain exactly
    ``expectedStates`` distinct 'statecode' values (51 is presumably the
    50 states plus DC -- confirm against the dataset).

    Raises ValueError when no year in the range qualifies. The previous
    implementation looped forever in that situation.
    """
    for candidate in range(firstYear, lastYear + 1):
        # The 'year' column is matched against the string form of the year.
        codes = set(data[data['year'] == str(candidate)]['statecode'])
        if len(codes) == expectedStates:
            return candidate

    raise ValueError(
        'No year has all states '
        f'(searched {firstYear}-{lastYear} for {expectedStates} distinct state codes)'
    )


#Determines the sum number of neighbor states without safety inspections for a given year
def getSumStatesWithoutPolicy(statecode, year, data, neighborData):
    """Count the neighbors of ``statecode`` with no safety inspection in ``year``.

    Parameters:
        statecode: code of the state whose neighbors are examined.
        year: year of interest; converted with float().
        data: DataFrame with at least the columns 'statecode', 'year',
            'safetyclass', 'startedsafety' and 'endedsafety'.
        neighborData: DataFrame passed through to getNeighborStates.

    Returns:
        The number of neighbor states that either have safetyclass
        'Never Had Safety Inspections' or whose inspection period
        [startedsafety, endedsafety) does not contain ``year``. A NaN
        'endedsafety' is treated as a period that never ends.

    Raises:
        ValueError: if no year between 1970 and 2017 contains all 51 states.

    Note: the reference year is recomputed on every call, which scans
    ``data`` repeatedly; callers looping over many rows pay that cost each time.
    """
    year = float(year)

    # Each neighbor's policy dates are read from a single year in which every
    # state is present, so that no neighbor is dropped from the count.
    referenceYear = _findReferenceYear(data)
    neighborStates = getNeighborStates(statecode, neighborData)

    snapshot = data[data['year'] == str(referenceYear)]
    snapshot = snapshot[snapshot['statecode'].isin(neighborStates)]

    #loop over the neighbor states
    numStates = 0
    for _, row in snapshot.iterrows():
        if row['safetyclass'] == 'Never Had Safety Inspections':
            numStates += 1
            continue

        policyEnd = float(row['endedsafety'])
        policyStart = float(row['startedsafety'])
        if policyEnd != policyEnd:
            # NaN end date: the policy is still in force.
            policyEnd = np.inf

        # The neighbor lacks the policy when year falls outside [start, end).
        if not (policyStart <= year < policyEnd):
            numStates += 1

    return numStates

#Determines the change in number treated from minYear to year
def getChangeInSumStatesWithoutPolicy(statecode, year, minYear, data, neighborData):
    """Return how the no-inspection neighbor count changed between minYear and year.

    Both ``year`` and ``minYear`` are converted with int(). When ``year``
    comes before ``minYear`` the result is np.nan; otherwise it is the count
    from getSumStatesWithoutPolicy at ``year`` minus the count at ``minYear``.
    """
    yearsSinceStart = int(year) - int(minYear)
    if yearsSinceStart < 0:
        # Observation predates the analysis window.
        return np.nan

    baseline = getSumStatesWithoutPolicy(statecode, minYear, data, neighborData)
    current = getSumStatesWithoutPolicy(statecode, year, data, neighborData)
    return current - baseline


#Load data from stata frame
cols = ['statecode', 'year', 'nosafetyind', 'safetyclass', 'endedsafety', 'startedsafety']

# NOTE(review): the values pass through np.array() before reaching pandas;
# with string and numeric variables mixed, every cell presumably ends up as a
# string, which is why the helpers compare against str(year) and call
# float()/int() on the values -- confirm against the Stata variable types.
data = p.DataFrame(np.array(Data.get(cols, missingval=np.nan)), columns=cols)

#Load data describing neighboring states
# Forward slashes are accepted as path separators on Windows as well as on
# macOS and Linux, so one relative path serves every platform. The path is
# resolved against the current working directory.
neighborFile = 'PythonScripts/Data/NeighboringStateData/cleanedNeighborStates.csv'
neighborData = p.read_csv(neighborFile)

minYear = 1970  # The year in which our analysis begins

# One value per observation, in the same order as the rows of data:
# the change in the number of neighbors without inspections since minYear.
delNeighborNoSafetyCol = [
    getChangeInSumStatesWithoutPolicy(obsState, obsYear, minYear, data, neighborData)
    for obsState, obsYear in zip(data['statecode'], data['year'])
]

#Store the new variable
Data.addVarDouble("delNeighborNoSafety")
Data.store("delNeighborNoSafety", None, np.array(delNeighborNoSafetyCol))